package cn.wzl.perceptron.perceptron;

import cn.wzl.perceptron.gif.ScalableXyCoordinateSystem;
import cn.wzl.perceptron.utils.ClassTransferUtil;
import com.madgag.gif.fmsware.AnimatedGifEncoder;

import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * A two-input perceptron trained on boolean samples, intended to learn the logical AND function.
 *
 * <p>Boolean inputs and expected values are converted to doubles with {@link ClassTransferUtil}
 * and handed to {@link SingleNodePerceptron}. The activation is a hard step: {@code 1.0} when the
 * value passed to {@link #active(double)} is {@code >= 0}, otherwise {@code 0.0}.
 *
 * <p>In debug mode each training step prints the current parameters, renders the decision line
 * into an image file under {@code output/image/}, and buffers that image so an animated GIF of the
 * whole run can later be written with {@link #generateDebugGif(String)}.
 */
public class AndPerceptron extends SingleNodePerceptron {

    /** Directory prefix for the per-step debug pictures. */
    private static final String DEBUG_IMAGE_DIR = "output/image/";

    /** Delay between frames of the debug GIF, in milliseconds. */
    private static final int DEBUG_FRAME_DELAY_MS = 500;

    private final boolean isDebugMode;

    /** Id of the next debug picture; also used as its file name and title. */
    private int debugPictureId = 0;

    /** Frames buffered for the debug GIF; grows by one per training step in debug mode. */
    private final List<BufferedImage> debugImages = new ArrayList<>();

    /**
     * Creates the perceptron.
     *
     * @param isDebugMode when {@code true}, every training step is logged and rendered
     */
    public AndPerceptron(boolean isDebugMode) {
        // NOTE(review): 2 matches the two weights w[0], w[1] drawn below; 0.2 is presumably the
        // learning rate — confirm against SingleNodePerceptron.
        super(2, 0.2);
        this.isDebugMode = isDebugMode;
    }

    /**
     * Performs one training step on a single sample.
     *
     * @param x             the boolean inputs, converted to doubles via {@link ClassTransferUtil}
     * @param expectedValue the expected output for {@code x}
     */
    public void training(boolean[] x, boolean expectedValue) {
        double e = ClassTransferUtil.transferBooleanToDouble(expectedValue);
        // Value returned by the superclass step; printed as "a" below (presumably the activation).
        double a = super.training(ClassTransferUtil.transferBooleansToDoubles(x), e);

        if (isDebugMode) {
            System.out.println("w:" + Arrays.toString(w) + " b:" + b + " dw:" + Arrays.toString(dw) + " a:" + a + " e:" + e);
            write2DStatusToPictureFile(DEBUG_IMAGE_DIR + debugPictureId + ".jpg", debugPictureId);
            debugPictureId++;
        }
    }

    /**
     * Evaluates the perceptron on boolean inputs.
     *
     * @param x the boolean inputs, converted to doubles via {@link ClassTransferUtil}
     * @return the superclass output for the converted inputs
     */
    public double run(boolean[] x) {
        return super.run(ClassTransferUtil.transferBooleansToDoubles(x));
    }

    /** Returns the signed error {@code expectedValue - actualValue}. */
    @Override
    protected double loss(double expectedValue, double actualValue) {
        return expectedValue - actualValue;
    }

    /**
     * Returns {@code dy} unchanged, i.e. treats the derivative of the step activation as 1.
     *
     * <p>Original author's notes on the cases that produce an update:
     * e = 1, a = 0 -> weightedValue < 0, dy = -1;
     * e = 0, a = 1 -> weightedValue > 0, dy = 1.
     * NOTE(review): with {@code loss = e - a} the raw loss in these cases is +1 and -1
     * respectively; the signs above hold only if SingleNodePerceptron negates it — confirm.
     */
    @Override
    protected double dActive(double dy) {
        // An earlier version returned weightedValue * dy, which made the gradient vanish.
        return dy;
    }

    /** Hard step activation: {@code 1.0} if {@code out >= 0}, otherwise {@code 0.0}. */
    @Override
    protected double active(double out) {
        return out >= 0 ? 1.0 : 0.0;
    }

    /**
     * Writes all buffered debug frames to an endlessly looping animated GIF and clears the buffer.
     *
     * <p>Does nothing outside debug mode or when no frames are buffered. If the output file
     * cannot be opened, an error is reported on {@code System.err} and the frames are kept so the
     * call can be retried with another file name.
     *
     * @param fileName path of the GIF file to write
     */
    public void generateDebugGif(String fileName) {
        if (!isDebugMode || debugImages.isEmpty()) {
            return;
        }
        AnimatedGifEncoder gifEncoder = new AnimatedGifEncoder();
        gifEncoder.setRepeat(0); // 0 = loop forever
        if (!gifEncoder.start(fileName)) {
            // Without a started encoder every addFrame() would silently be a no-op.
            System.err.println("Failed to open debug GIF for writing: " + fileName);
            return;
        }
        // The delay applies to every frame added afterwards, so it is set once.
        gifEncoder.setDelay(DEBUG_FRAME_DELAY_MS);
        for (BufferedImage image : debugImages) {
            gifEncoder.addFrame(image);
        }
        if (!gifEncoder.finish()) {
            System.err.println("Failed to finish writing debug GIF: " + fileName);
        }
        debugImages.clear();
    }

    /** Buffers one frame for the debug GIF. */
    private void write2DStatusToGif(BufferedImage bufferedImage) {
        debugImages.add(bufferedImage);
    }

    /**
     * Renders the reference points and the current decision line, buffers the image for the GIF
     * and writes it to {@code fileName}. Writing the file is best-effort: failures are reported
     * and training continues.
     *
     * @param fileName  target image path
     * @param pictureId id drawn as the picture title
     */
    private void write2DStatusToPictureFile(String fileName, int pictureId) {
        ScalableXyCoordinateSystem coordinateSystem = new ScalableXyCoordinateSystem(5, 5, 50);
        drawReferencePoint(coordinateSystem);
        coordinateSystem.drawTitle(-4, 4, Integer.toString(pictureId));
        // NOTE(review): presumably draws w[0]*x + w[1]*y + b = 0, the decision boundary — confirm.
        coordinateSystem.drawLine(w[0], w[1], b, Color.RED);
        write2DStatusToGif(coordinateSystem.retrieveBufferedImage());
        try {
            coordinateSystem.writeToFile(fileName);
        } catch (IOException ex) {
            // Debug output only: report and keep training rather than failing.
            System.err.println("Failed to write debug picture " + fileName + ": " + ex);
        }
    }

    /** Plots the four AND inputs: (1,1), the only true case, in green; the others in black. */
    private void drawReferencePoint(ScalableXyCoordinateSystem scalableXyCoordinateSystem) {
        scalableXyCoordinateSystem.drawPoint(1, 1, Color.GREEN);
        scalableXyCoordinateSystem.drawPoint(1, 0, Color.BLACK);
        scalableXyCoordinateSystem.drawPoint(0, 1, Color.BLACK);
        scalableXyCoordinateSystem.drawPoint(0, 0, Color.BLACK);
    }

}
